import os, time, cv2, paddle
import numpy as np
from matplotlib import pyplot as plt
from src.model.ssyolo import SSYOLO

class Inferer():
    def __init__(self, 
                 num_classes  =4,
                 model_param  ='./out/model.pdparams',
                 label_path   ='./data/label.txt',
                 output_dir   ='./out/',
                 scale_size   =608,
                 interpolation=cv2.INTER_LINEAR,
                 mean         =[0.485, 0.456, 0.406],
                 stdv         =[0.229, 0.224, 0.225],
                 bbox_thick   =1,
                 font_thick   =1,
                 font_scale   =0.4):
        """
        初始化推理器
        params:
        - num_classes  : 物体类别数量
        - model_param  : 模型参数路径
        - label_path   : 物体标签路径
        - output_dir   : 输出结果目录
        - scale_size   : 图像缩放大小
        - interpolation: 缩放插值方法
        - mean         : 数据通道均值
        - stdv         : 数据通道方差
        - bbox_thick   : 物体边框粗细
        - font_thick   : 文本边框粗细
        - font_scale   : 文本字体大小
        """
        # 设置变量
        self.label_path = label_path           # 物体标签路径
        self.label_list = self.get_labellist() # 获取标签列表
        self.color_list = self.get_colorlist() # 获取颜色列表
        self.output_dir = output_dir           # 输出结果目录
        
        self.scale_size = scale_size           # 图像缩放大小
        self.interpolation = interpolation     # 缩放插值方法
        self.mean = mean                       # 数据通道均值
        self.stdv = stdv                       # 数据通道方差
        self.bbox_thick = bbox_thick           # 物体边框粗细
        self.font_thick = font_thick           # 文本边框粗细
        self.font_scale = font_scale           # 文本字体大小
        
        # 声明模型
        self.model = SSYOLO(num_classes)       # 声明网络模型
        self.model.eval()                      # 设置验证模式
        
        # 加载权重
        if self.model is not None and os.path.exists(model_param): # 是否加载权重
            model_state_dict = paddle.load(model_param)            # 加载模型权重
            self.model.set_state_dict(model_state_dict)            # 设置模型权重
        else:
            print('警告：模型权重加载失败！')
            
    def get_labellist(self):
        """
        获取标签列表
        return:
        - label_list: 标签列表
        """
        assert os.path.exists(self.label_path), '错误：标签文件不存在！' # 检测标签文件
        
        label_list = [] # 标签列表
        with open(self.label_path, 'r') as f:    # 打开标签文件
            for label in f.readlines():          # 读取标签名称
                label_list.append(label.strip()) # 添加标签列表
    
        return label_list
    
    def get_colorlist(self):
        """
        获取颜色列表
        return:
        - color_list: 颜色列表, (80,3)
        """
        color_list = np.array([
            0.000, 0.447, 0.741, 0.850, 0.325, 0.098, 0.929, 0.694, 0.125, 0.494,
            0.184, 0.556, 0.466, 0.674, 0.188, 0.301, 0.745, 0.933, 0.635, 0.078,
            0.184, 0.300, 0.300, 0.300, 0.600, 0.600, 0.600, 1.000, 0.000, 0.000,
            1.000, 0.500, 0.000, 0.749, 0.749, 0.000, 0.000, 1.000, 0.000, 0.000,
            0.000, 1.000, 0.667, 0.000, 1.000, 0.333, 0.333, 0.000, 0.333, 0.667,
            0.000, 0.333, 1.000, 0.000, 0.667, 0.333, 0.000, 0.667, 0.667, 0.000,
            0.667, 1.000, 0.000, 1.000, 0.333, 0.000, 1.000, 0.667, 0.000, 1.000,
            1.000, 0.000, 0.000, 0.333, 0.500, 0.000, 0.667, 0.500, 0.000, 1.000,
            0.500, 0.333, 0.000, 0.500, 0.333, 0.333, 0.500, 0.333, 0.667, 0.500,
            0.333, 1.000, 0.500, 0.667, 0.000, 0.500, 0.667, 0.333, 0.500, 0.667,
            0.667, 0.500, 0.667, 1.000, 0.500, 1.000, 0.000, 0.500, 1.000, 0.333,
            0.500, 1.000, 0.667, 0.500, 1.000, 1.000, 0.500, 0.000, 0.333, 1.000,
            0.000, 0.667, 1.000, 0.000, 1.000, 1.000, 0.333, 0.000, 1.000, 0.333,
            0.333, 1.000, 0.333, 0.667, 1.000, 0.333, 1.000, 1.000, 0.667, 0.000,
            1.000, 0.667, 0.333, 1.000, 0.667, 0.667, 1.000, 0.667, 1.000, 1.000,
            1.000, 0.000, 1.000, 1.000, 0.333, 1.000, 1.000, 0.667, 1.000, 0.167,
            0.000, 0.000, 0.333, 0.000, 0.000, 0.500, 0.000, 0.000, 0.667, 0.000,
            0.000, 0.833, 0.000, 0.000, 1.000, 0.000, 0.000, 0.000, 0.167, 0.000,
            0.000, 0.333, 0.000, 0.000, 0.500, 0.000, 0.000, 0.667, 0.000, 0.000,
            0.833, 0.000, 0.000, 1.000, 0.000, 0.000, 0.000, 0.167, 0.000, 0.000,
            0.333, 0.000, 0.000, 0.500, 0.000, 0.000, 0.667, 0.000, 0.000, 0.833,
            0.000, 0.000, 1.000, 0.000, 0.000, 0.000, 0.143, 0.143, 0.143, 0.286,
            0.286, 0.286, 0.429, 0.429, 0.429, 0.571, 0.571, 0.571, 0.714, 0.714,
            0.714, 0.857, 0.857, 0.857, 1.000, 1.000, 1.000, 0.000, 0.000, 0.000
        ]).astype(np.float32)                            # 声明颜色列表
        color_list = color_list.reshape((-1, 3)) * 255   # 转换颜色范围,[0,255]
        color_list = color_list.astype('int32').tolist() # 转换列表格式
        
        return color_list
    
    def infer(self, image_path):
        """
        进行图像推理
        params:
        - image_path: 图像路径
        """
        # 读取图像
        infer_time = time.time()                       # 开始时间
        images, imghws = self.read_image(image_path)   # 读取图像
        
        # 前向传播
        p_list = self.model(images)                    # 前向传播
        
        # 计算预测
        infers = self.model.get_pred(p_list, imghws)   # 计算预测
        infer_time = (time.time() - infer_time) * 1000 # 结束时间
        
        # 保存结果
        print(f'infer time: {infer_time:.3f} ms')      # 打印时间
        self.save_image(image_path, infers)            # 保存结果
        
    def read_image(self, image_path):
        """
        读取推理图像
        params:
        - image_path: 图像路径
        return:
        - images    : 图像数据
        - imghws    : 图像高宽
        """
        # 读取图像
        assert os.path.exists(image_path), '错误：图像文件不存在！' # 检测图像文件
        with open(image_path, 'rb') as f:                       # 打开图像文件
            image = f.read()                                    # 读取图像数据
        image = np.frombuffer(image, dtype='uint8')             # 读到数组缓存
        image = cv2.imdecode(image, 1)                          # 解码图片通道
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)          # 转换图片通道
        
        # 读取高宽
        img_h = float(image.shape[0])                     # 图像高度
        img_w = float(image.shape[1])                     # 图像宽度
        imghw = np.array([img_h, img_w], dtype='float32') # 图像高宽
        
        # 缩放图像
        image = cv2.resize(image, (self.scale_size, self.scale_size), self.interpolation)
        
        # 归一图像
        image = image.astype('float32', copy=False)                             # 转换数据格式
        mean  = np.array(self.mean, dtype='float32')[np.newaxis, np.newaxis, :] # 生成均值矩阵
        stdv  = np.array(self.stdv, dtype='float32')[np.newaxis, np.newaxis, :] # 生成方差矩阵
        image = ((image/255.0) - mean) / stdv                                   # 归一化[0,1]
        
        # 变换通道
        image = image.transpose((2, 0, 1)) # 图像通道从HWC变换为CHW
        
        # 增加维度
        images = paddle.to_tensor(np.expand_dims(image, axis=0)) # 增加维度并转换为Tensor，[1,3,608,608]
        imghws = paddle.to_tensor(np.expand_dims(imghw, axis=0)) # 增加维度并转换为Tensor，[1,2]
        
        return images, imghws
    
    def save_image(self, image_path, infers):
        """
        保存推理图像
        params:
        - image_path: 图像路径
        - infers    : 预测结果
        """
        # 读取图像
        assert os.path.exists(image_path), '错误：图像文件不存在！' # 检测图片文件
        with open(image_path, 'rb') as f:                       # 打开图像文件
            image = f.read()                                    # 读取图像数据
        image = np.frombuffer(image, dtype='uint8')             # 读到数组缓存
        image = cv2.imdecode(image, 1)                          # 解码图片通道
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)          # 转换图片通道
        
        # 保存图像
        for infer in infers:                                                          # 遍历批次
            # 绘制边框
            for item in infer:                                                        # 遍历物体
                # 获取物体
                label = self.label_list[int(item[0])]                                 # 标签名称
                score = item[1]                                                       # 预测得分
                x1, y1, x2, y2 = item[2:6].astype('int32')                            # 边框坐标
                
                # 绘制边框
                bbox_color = self.color_list[int(item[0])]                            # 边框颜色
                cv2.rectangle(image, (x1, y1), (x2, y2), bbox_color, self.bbox_thick) # 绘制边框
                
                # 绘制标签
                text       = f'{label}: {score:.2f}'                                  # 标签内容
                font_face  = cv2.FONT_HERSHEY_SIMPLEX                                 # 字体样式
                font_color = [255, 255, 255]                                          # 字体颜色
                rect_color = self.color_list[int(item[0])]                            # 边框颜色
                rect_thick = -1                                                       # 边框粗细
                
                text_size, base_line = cv2.getTextSize(text, font_face, self.font_scale, self.font_thick)             # 文本尺寸
                cv2.rectangle(image, (x1, y1-base_line-text_size[1]), (x1+text_size[0], y1), rect_color, rect_thick)  # 文本边框
                cv2.putText(image, text, (x1, y1-base_line), font_face, self.font_scale, font_color, self.font_thick) # 文本内容
            
            # 显示图像
            plt.figure(figsize=(6, 6)) # 创建图表
            plt.imshow(image)          # 显示图像
            plt.axis('off')            # 关闭坐标
            plt.tight_layout()         # 紧缩布局
            plt.show()                 # 显示图表
                
            # 保存图像
            if not os.path.exists(self.output_dir):                                  # 是否存在目录
                os.makedirs(self.output_dir)                                         # 创建保存目录
            save_path = os.path.join(self.output_dir, os.path.split(image_path)[-1]) # 设置保存路径
            image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)                           # 转换图片格式
            cv2.imwrite(save_path, image)                                            # 保存图像文件

if __name__ == "__main__":
    # 实例化推理器
    inferer = Inferer(
        num_classes=4,                      # 物体类别数量
        model_param='./out/model.pdparams', # 模型参数路径
        label_path ='./data/label.txt',     # 物体标签路径
        output_dir ='./out/',               # 输出结果目录
        scale_size =608,                    # 图像缩放大小
        bbox_thick =1,                      # 物体边框粗细
        font_thick =1,                      # 文本边框粗细
        font_scale =0.4                     # 文本字体大小
    )

    # 启动推理图像
    inferer.infer(image_path='./data/images/road554.png')